#ifndef _PARSER_
#define _PARSER_

#pragma once
#include <vector>
#include <iostream>
#include <fstream>
#include <iomanip>
#include <set>
using namespace std;

#include "IOPipe.h"
#include "FGen.h"
#include "Decoder.h"

#include "Parameters.h"
#include "common.h"
#include "GzFile.h"

/*******************
There seems to be a conflict between "CharUtils.h" and "spthread.h":
the order of their #include directives below must not be reversed.
The cause is unknown.
 *******************/
#include "CharUtils.h"
#include "StringMap.h"
using namespace egstra;

#include "spthread.h"
#include "threadpool.h"
/*******************/

#include "lbfgs-crfpp.h" // crf++

namespace dparser {
	/*
	this class controls the parsing process.
	*/
	class Parser
	{
	public:
		// Data pipes. The three training pipes are addressed back-to-back by get_instance()
		// (train, then train2, then train3) and released in post_train().
		IOPipe m_pipe_train;
		IOPipe m_pipe_train2;
		IOPipe m_pipe_train3;
		IOPipe m_pipe_test;
		IOPipe m_pipe_dev;
		IOPipe m_pipe_dev2;
		// Feature generator; dictionaries are loaded and saved through it.
		FGen m_fgen;
		// Model parameters; dot(), save(), load() and delete_file() are called on it in this header.
		parameters m_param;
		// Set to 0 by the constructor and released through delete_decoder() in the destructor.
		Decoder *m_decoder;

/* options */
	private:
		int _display_interval;

		// Passed to m_fgen.load_dictionaries() / save_dictionaries().
		string _dictionary_path;
		// Passed to m_param.save() / load() / delete_file().
		string _parameter_path;
		// NOTE(review): presumably a sentence-length cutoff for discarding instances — confirm in process_options().
		int _inst_max_len_to_throw;

		// run() executes pre_train(), the selected trainer and post_train() when this is true.
		bool _train;
		int _iter_num;
        int _best_iter_num_so_far_a;
        double _best_accuracy_a;
        int _best_iter_num_so_far_b;
        double _best_accuracy_b;

		// Maps the per-iteration index given to get_instance() to an index over the
		// concatenation of the train / train2 / train3 pipes.
		vector<int> _inst_idx_to_read;
		
		bool _self_training;
		// post_train() releases m_pipe_train2 / m_pipe_train3 only when the matching flag is set.
		bool _use_train2;
		bool _use_train3;
		int _inst_num_from_train2_one_iter;
		int _inst_num_from_train3_one_iter;
		int _inst_num_from_train1_one_iter;
		string _filename_train2;
		string _filename_train3;
		int _inst_max_num_train2;
		int _inst_max_num_train3;

		// post_train() releases m_pipe_dev2 only when this is set.
		bool _use_dev2;
		string _filename_train;
		string _filename_dev;
		string _filename_dev2;

		int _inst_max_num_train;
		bool _dictionary_exist;
		// sic: "parameter" is misspelled; the identifier is kept as-is because code outside
		// this header may refer to it.
		bool _pamameter_exist;
		int _param_tmp_num;

		// run() calls test(_param_num_for_eval) when this is true.
		bool _test;
		string _filename_test;
		string _filename_output;
		int _param_num_for_eval;
		int _inst_max_num_eval;
		int _test_batch_size;

		// Only consulted by parse() in its non-MBR branch.
		bool _verify_decoding_algorithm;
		
		// The constructor passes max(1, _thread_num) to create_threadpool().
		int _thread_num;

		/* thread control */
		// Created in the constructor via create_threadpool(), destroyed in the destructor.
		threadpool _tp;
		// The static data members below are only declared here; each needs a definition
		// outside the class (not visible in this header).
		static sp_thread_mutex_t _mutex;
		static sp_thread_cond_t _cond_waiting_create_feat;	// waiting for the decoding-thread
		static sp_thread_cond_t _cond_waiting_update;	// waiting for the creating-features-thread to finish the current instance.
		static sp_thread_cond_t _cond_done_update;		// complete all the instances

		static vector<bool> _train_features_created;
		static int _train_create_feat_inst_i;
		static int _train_update_inst_i;

		floatval_t _l2sgd_calibration_init_loss;
		// NOTE(review): presumably the running state of the L2-SGD trainer — confirm in the .cpp.
		static double _sum_loss, _t0, _t, _lambda, _eta, _decay, _gain;
		static vector<double> _g;
		// parse() takes the marginal-based decoding branch when this is true.
		static bool _mbr_decoding;
		// parse() calls filter_tag() after MBR decoding when this is true.
		static bool _test_tag_filter;
		static double _test_tag_filter_lambda;

		// evaluate_output_tag_filter() writes one line of candidate probabilities per word here.
		ofstream _of_tag_filter_prob;
		// Counters zeroed by initialize_filter_stat(), accumulated by
		// evaluate_output_tag_filter() and printed by output_filter_stat().
		int tot_word_num;
		int tot_tag_num;
		int tot_correct_tag_num;
		// Zero the three tag-filter counters (words, candidate tags, correctly covered words).
		void initialize_filter_stat() {
			tot_word_num = tot_tag_num = tot_correct_tag_num = 0;
		}
		void output_filter_stat()
		{
			cerr << " joint-tag filter results: " << endl;
			fprintf(stderr, "oracle POS tagging accuracy: %d / %d = %.2f\n", 
				tot_correct_tag_num, tot_word_num, 100.0*tot_correct_tag_num/tot_word_num);
			fprintf(stderr, "average tag num per word: %d / %d = %.2f \n", 
				tot_tag_num, tot_word_num, 1.0*tot_tag_num/tot_word_num);
		}

		/*
		 * Accumulate tag-filter statistics for one instance and dump its candidate lists.
		 *
		 * For every word wi in [1, inst->size()):
		 *   - the gold tag inst->cpostags[wi] is tokenized with simpleTokenize() using "^";
		 *     it must yield two parts, exactly one of which is the wildcard "*"
		 *     (checked by assert only, so not enforced when NDEBUG is defined);
		 *   - the candidates in inst->filtered_tags[wi] are joined with "_" and stored in
		 *     inst->pdeprels[wi]; the values in inst->filtered_probs[wi] are joined the same
		 *     way and written as one line to _of_tag_filter_prob;
		 *   - the word counts as correct when some candidate agrees with the gold tag on an
		 *     annotated (non-"*") part. A wildcard part of the gold tag never counts as a
		 *     match, so a candidate that itself contains "*" cannot be scored as correct by
		 *     accident.
		 * A blank line is written to _of_tag_filter_prob after the last word.
		 *
		 * Updates tot_word_num, tot_tag_num and tot_correct_tag_num.
		 * inst must be non-null, and filtered_probs[wi] must hold at least as many entries as
		 * filtered_tags[wi]; neither is checked here.
		 */
		void evaluate_output_tag_filter(Instance *inst) {
			const int len = inst->size();
			tot_word_num += (len - 1);
			for (int wi = 1; wi < len; ++wi) {
				const string &gold = inst->cpostags[wi];
				vector<string> vecgold;
				simpleTokenize(gold, vecgold, "^");
				assert(vecgold.size() == 2);
				assert(vecgold[0] != "*" || vecgold[1] != "*");
				assert(vecgold[0] == "*" || vecgold[1] == "*");
				// Only the annotated part(s) of the gold tag may be matched against.
				const bool gold_has_first = (vecgold[0] != "*");
				const bool gold_has_second = (vecgold[1] != "*");

				// Fetch the candidate count once, as int, so the loop compares int with int.
				const int tag_num = static_cast<int>(inst->filtered_tags[wi].size());
				tot_tag_num += tag_num;

				ostringstream os_head_list;
				ostringstream os_prob_list;
				//os_prob_list.precision(15);
				bool correct_exist = false;
				for (int ti = 0; ti < tag_num; ++ti) {
					const string &sys = inst->filtered_tags[wi][ti];
					const char * const sep = (ti == 0 ? "" : "_");
					os_head_list << sep << sys;
					os_prob_list << sep << inst->filtered_probs[wi][ti];

					vector<string> vecsys;
					simpleTokenize(sys, vecsys, "^");
					assert(vecsys.size() == 2);
					if ((gold_has_first && vecsys[0] == vecgold[0]) ||
						(gold_has_second && vecsys[1] == vecgold[1])) {
						correct_exist = true;
					}
				}
				if (correct_exist) ++tot_correct_tag_num;
				_of_tag_filter_prob << os_prob_list.str() << endl;
				inst->pdeprels[wi] = os_head_list.str();
			}
			_of_tag_filter_prob << endl;
		}

		/* variables used in evaluate */
		// NOTE(review): presumably cleared by reset_evaluate_metrics() and reported by
		// output_evaluate_metrics(); neither is defined in this header — confirm in the .cpp.
		int inst_num_processed_total;
		int noov;
		int noov_correct;
		int nword;
		int ncorrect_joint;
		int ncorrect_a_max;
		int ncorrect_b_max;
		//int ncorrect_a_sum;
		int ncorrect_b_sum;

		// Selects the trainer in run(): "lbfgs-crfpp" or "l2sgd"; any other value aborts.
		string _train_method;
		static bool _gradient_update_allow_conflict;

		/*
		 * Options of the L2-regularized SGD trainer.
		 * (A max_iterations field, the maximum number of SGD epochs, is currently disabled.)
		 */
		struct l2sgd_training_option_t {
			int batch_size;

			// Coefficient for L2 regularization.
			floatval_t c2;

			// The duration of iterations to test the stopping criterion.
			int period;

			// The threshold for the stopping criterion; an optimization process stops when
			// the improvement of the log likelihood over the last ${period} iterations is no
			// greater than this threshold.
			floatval_t delta;

			// The initial value of learning rate (eta) used for calibration.
			floatval_t calibration_eta;

			// The rate of increase/decrease of learning rate for calibration.
			floatval_t calibration_rate;

			// The number of instances used for calibration.
			int calibration_samples;

			// The number of candidates of learning rate.
			int calibration_candidates;

			// The maximum number of trials of learning rates for calibration.
			int calibration_max_trials;

			floatval_t lambda;
			floatval_t t0;
		};
		l2sgd_training_option_t _l2sgd_opt;
		
		/** From the CRF++ documentation:
			-a CRF-L2 or CRF-L1:
				Changes the regularization algorithm. The default setting is L2.
				Generally speaking, L2 performs slightly better than L1,
				while the number of non-zero features in L1 is drastically smaller
				than that in L2.
			-c float:
				With this option, you can change the hyper-parameter for the CRFs.
				With a larger C value, the CRF tends to overfit to the given training corpus.
				This parameter trades off the balance between overfitting and underfitting.
				The results will be significantly influenced by this parameter.
				You can find an optimal value by using held-out data or a more general
				model selection method such as cross validation.
		*/
		// Options of the CRF++-style L-BFGS trainer.
		struct lbfgs_crfpp_train_opt_t {
			// L1 or L2 regularization.
			bool isL2;

			// Coefficient for L1 or L2 regularization.
			double c;

			// Termination criterion (default 0.0001).
			double eta;

			// Number of iterations a variable needs to be optimal before it is
			// considered for shrinking (default 20).
			int shrinking_size;
		};
		lbfgs_crfpp_train_opt_t _lbfgs_crfpp_opt;


	public:
		// Starts with no decoder and no thread pool, reads the options, then creates the
		// pool with create_threadpool(max(1, _thread_num)), so the argument is never below 1.
		// NOTE(review): only m_decoder and _tp are initialized explicitly; process_options()
		// is presumably expected to assign every option member, including _thread_num.
		Parser() : m_decoder(0), _tp(0) {
			process_options();
			_tp = create_threadpool(max(1, _thread_num));
		}

		// Releases the decoder (delete_decoder() also resets m_decoder to 0) and destroys
		// the thread pool created by the constructor.
		// NOTE(review): the class holds a raw decoder pointer and a pool handle but declares
		// no copy constructor or assignment; copying a Parser would release both twice.
		~Parser(void) {
			delete_decoder(m_decoder);
			destroy_threadpool(_tp);
			_tp = 0;
		}

		// Called first by the constructor, before the thread pool is created; defined outside
		// this header. NOTE(review): presumably fills the option members (the constructor reads
		// _thread_num right after this call) — confirm in the .cpp.
		void process_options();

		/*
		 * Top-level driver.
		 *
		 * When _train is set: pre_train(), then the trainer named by _train_method
		 * ("lbfgs-crfpp" or "l2sgd"), then post_train(). An unrecognized method name is
		 * reported on stderr and the process exits with status -1.
		 * When _test is set: test() is run with _param_num_for_eval.
		 */
		void run()
		{
			if (_train) {
				pre_train();
				if (_train_method == "l2sgd") {
					train_l2sgd();
				} else if (_train_method == "lbfgs-crfpp") {
					train_lbfgs_crfpp();
				} else {
					// A "pa" (passive-aggressive) trainer used to be dispatched here; it is disabled.
					cerr << "unknown train method: " << _train_method << endl;
					exit(-1);
				}
				post_train();
			}

			if (_test) {
				test(_param_num_for_eval);
			}
		}

		// Allocate a Decoder, let it read its options, and hand it to the caller, who must
		// release it (see delete_decoder()).
		static Decoder *new_decoder() {
			Decoder * const fresh = new Decoder();
			assert(fresh);
			fresh->process_options();
			return fresh;
		}

		// Destroy the decoder (if any) and reset the caller's pointer to 0.
		static void delete_decoder(Decoder *&decoder) {
			delete decoder;	// deleting a null pointer is a no-op
			decoder = 0;
		}

	private:
		/*
		 * Immutable bundle of per-task arguments: the owning parser, an instance count,
		 * and optionally one instance with its index and a test flag.
		 * NOTE(review): presumably passed as the void* argument of the static thread entry
		 * points of this class — confirm in the .cpp.
		 *
		 * All members are const, so objects can be constructed and copied but not assigned.
		 *
		 * Declared as a plain struct: the former "typedef struct thread_arg_t { ... } ;" had
		 * no declarator, so the typedef keyword was ignored and only produced a warning.
		 */
		struct thread_arg_t {
			thread_arg_t(Parser * const parser, const int inst_num, Instance * const inst=0, const int inst_idx = -1, bool is_test=false)
				: _parser(parser), _inst_num(inst_num), _inst(inst), _inst_idx(inst_idx), _is_test(is_test) {}
			Parser * const _parser;
			const int _inst_num;
			Instance * const _inst;		// 0 when no single instance is attached
			const int _inst_idx;		// -1 when no single instance is attached
			const bool _is_test;
		};

		// Static entry points taking an opaque void* argument; defined outside this header.
		// NOTE(review): presumably executed by the workers of _tp with a thread_arg_t*
		// argument — confirm in the .cpp.
		static void parse_one_inst_thread( void *arg );
		static void train_update_one_inst_thread( void *arg );
		static void l2sgd_calibration_compute_init_loss( void *arg );

		double update_weights_or_gradients_with_gold_tree(const Instance *const inst, double * const g, const double gain) {
			/* add observed positive features */
			sparsevec sp_fv;
			m_fgen.create_all_pos_features_according_to_tree(inst, sp_fv, inst->cpostags);
			const double score = m_param.dot(sp_fv); // before update!

			sparsevec::const_iterator V_i = sp_fv.begin();
			const sparsevec::const_iterator V_end = sp_fv.end();
			for(; V_i != V_end; ++V_i) {
				const int id = V_i->first;
				const double val = V_i->second;
				assert(id < m_fgen.feature_dimentionality() && id >= 0);
				g[id] += val * gain;
			}
			return score;
		}

		double update_weights_or_gradients_with_gold_tree(const Instance *const inst, sparsevec &g, const double gain) {
			/* add observed positive features */
			sparsevec sp_fv;
			m_fgen.create_all_pos_features_according_to_tree(inst, sp_fv, inst->cpostags);
			const double score = m_param.dot(sp_fv); // before update!
			
			parameters::sparse_add(g, sp_fv, gain);
			return score;
		}

		// Trainers selected by run() for _train_method "lbfgs-crfpp" and "l2sgd" respectively;
		// defined outside this header.
		void train_lbfgs_crfpp();
		void train_l2sgd();
		floatval_t l2sgd_calibration();

		// objective: min -(1/N) * \sum{logP(y|x)} + (\lambda / 2) * ||w||^2
		// sum_loss: -\sum{logP(y|x)} + C * ||w||^2   \lambda = 2C/N
		void l2sgd( const int N, 
			const floatval_t t0, 
			const floatval_t lambda, 
			const int num_epochs, 
			const bool calibration, 
			const int period, 
			const floatval_t epsilon);

		// Carries a back-pointer to the Parser.
		struct lbfgs_internal_t {
			Parser *par;
		};

		/*
		 * Decode one instance with the given decoder.
		 *
		 * Steps: build the constrained-tag matrix when the instance carries constraints,
		 * create all feature vectors, compute all probabilities, then decode. Finally the
		 * predicted feature vector is cleared and the feature/probability storage of the
		 * instance is released through m_fgen.
		 *
		 * With _mbr_decoding set, the decoder computes marginals, which are then used as
		 * scores; the best joint tag sequence goes to inst->predicted_tags_joint and the
		 * tag filter runs when _test_tag_filter is set. The per-side (a / b) unigram tables
		 * and tag sequences are currently disabled.
		 * Without _mbr_decoding the code asserts: that path only executes when NDEBUG is
		 * defined.
		 *
		 * is_test is not used by this function.
		 */
		void parse(Decoder *decoder, Instance *inst, bool is_test) {
			const bool has_constraints = !inst->constrained_tags_str.empty();
			if (has_constraints) {
				m_fgen.create_constrained_tag_matrix(inst);
			}

			m_fgen.create_all_feature_vectors(inst);
			compute_all_probs(inst);

			if (!_mbr_decoding) {
				assert(false);
				decoder->decodeInterface(inst, has_constraints);
				m_fgen.assign_predicted_tag_str(inst);
				if (_verify_decoding_algorithm) {
					verify_decoding_algorithm(inst);
				}
			} else {
				decoder->compute_marginals(inst, has_constraints);

				const int len = inst->size();
				// Size the joint unigram table on first use only.
				if (inst->prob_unigram_joint.empty()) {
					inst->prob_unigram_joint.resize(len+1, m_fgen.tag_number());
				}
				decoder->use_marginal_as_arc_score(inst);

				// 0 selects the joint tag set.
				get_best_tag_seq(inst, 0, inst->predicted_tags_joint);

				if (_test_tag_filter) {
					filter_tag(inst, 0);
				}
			}

			inst->predicted_fv.clear();
			m_fgen.dealloc_fvec_prob(inst);
		}
		
		// Map tag id t to its string in the tag set chosen by joint_a_b:
		// 0 = joint, 1 = a, any other value = b.
		const char * pos_id_2_str(const int joint_a_b, const int t) {
			switch (joint_a_b) {
			case 0:
				return m_fgen.pos_id_2_str(t);
			case 1:
				return m_fgen.pos_id_2_str_a(t);
			default:
				return m_fgen.pos_id_2_str_b(t);
			}
		}
		
	//	static void assign_1_best_tag_seq(Instance *inst);
		// For both functions joint_a_b selects the tag set (0-joint, 1-a, 2-b); parse() calls
		// them with 0. Defined outside this header.
		void get_best_tag_seq( Instance *inst, const int joint_a_b, vector<string> &predicted_tags ); // 0-joint,1-a,2-b
		void filter_tag( Instance *inst, const int joint_a_b ); // 0-joint,1-a,2-b

		// Called by run() with _param_num_for_eval when _test is set.
		void test(const int iter);

		/*
		 * Return the training instance scheduled at position inst_idx.
		 *
		 * _inst_idx_to_read[inst_idx] is an index over the concatenation of the three
		 * training pipes (train, then train2, then train3); it is translated to an index
		 * local to the pipe it falls into. inst_idx itself is not range-checked.
		 */
		Instance *get_instance(const int inst_idx) {
			const int global_idx = _inst_idx_to_read[inst_idx];

			if (global_idx < m_pipe_train.getInstanceNum()) {
				return m_pipe_train.getInstance(global_idx);
			}
			if (global_idx < m_pipe_train.getInstanceNum() + m_pipe_train2.getInstanceNum()) {
				return m_pipe_train2.getInstance(global_idx - m_pipe_train.getInstanceNum());
			}
			return m_pipe_train3.getInstance(global_idx - m_pipe_train.getInstanceNum() - m_pipe_train2.getInstanceNum());
		}

		/*
		 * Release a training instance once its gradient contribution has been applied.
		 *
		 * inst->id is interpreted as an index over the concatenation of the three training
		 * pipes (train, then train2, then train3). The instance is deleted, and inst reset
		 * to 0, only when the pipe it falls into reports use_instances_posi(); otherwise
		 * nothing happens.
		 * NOTE(review): use_instances_posi() presumably means the pipe re-reads instances on
		 * demand instead of keeping them — confirm in IOPipe.
		 *
		 * A null inst is accepted and ignored; because this function nulls the pointer
		 * itself, a repeated call on the same variable must not dereference it.
		 */
		void delete_one_train_instance_after_update_gradient(Instance *&inst) {
			if (!inst) return;

			// Ask the pipe that the instance belongs to whether it may be deleted.
			bool release;
			if (inst->id < m_pipe_train.getInstanceNum()) {
				release = m_pipe_train.use_instances_posi();
			} else if (inst->id < m_pipe_train.getInstanceNum() + m_pipe_train2.getInstanceNum()) {
				release = m_pipe_train2.use_instances_posi();
			} else {
				release = m_pipe_train3.use_instances_posi();
			}

			if (release) {
				delete inst;
				inst = 0;
			}
		}

		// Size of _inst_idx_to_read, i.e. how many positions get_instance() can be asked for.
		int get_inst_num_one_iter() const {
			return static_cast<int>(_inst_idx_to_read.size());
		}
		// Defined outside this header.
		void prepare_train_instances();
		// Called by run() before the selected trainer when _train is set.
		void pre_train();
		// Release the data held for training; called by run() after the trainer returns.
		// Training pipes: instances are deallocated and the input file is closed
		// (train2 / train3 only when _use_train2 / _use_train3 is set).
		// Dev pipes: only dealloc_instance() is called here (dev2 only when _use_dev2 is set).
		void post_train() {
			m_pipe_train.dealloc_instance();
			m_pipe_train.closeInputFile();
			if (_use_train2) {
				m_pipe_train2.dealloc_instance();
				m_pipe_train2.closeInputFile();
			}
			if (_use_train3) {
				m_pipe_train3.dealloc_instance();
				m_pipe_train3.closeInputFile();
			}
			m_pipe_dev.dealloc_instance();
			if (_use_dev2) m_pipe_dev2.dealloc_instance();
		}
		// For every instance of the pipe, release its constrained_tags storage when non-empty.
		void delete_candidate_heads(IOPipe &pipe) {
			int idx = 0;
			while (idx < pipe.getInstanceNum()) {
				Instance * const cur = pipe.getInstance(idx);
				if (!cur->constrained_tags.empty()) {
					cur->constrained_tags.dealloc();
				}
				++idx;
			}
		}

		// Evaluation entry point and its metric helpers; defined outside this header.
		void evaluate(IOPipe &pipe, const bool is_test);
		void reset_evaluate_metrics();
		void output_evaluate_metrics();

		// Defined outside this header.
		void create_dictionaries(IOPipe &pipe, const bool collect_word);

		// Load the dictionaries from _dictionary_path through m_fgen, then publish the tag
		// count and the id of DUMMY_CPOSTAG to the Decoder's static members T / pos_id_dummy.
		void load_dictionaries() {
			m_fgen.load_dictionaries(_dictionary_path);
			Decoder::T = m_fgen.tag_number();
			Decoder::pos_id_dummy = m_fgen.get_pos_id(DUMMY_CPOSTAG);
		}

		// Save the dictionaries held by m_fgen under _dictionary_path.
		void save_dictionaries() {
			m_fgen.save_dictionaries(_dictionary_path);
		}

		// Forward to m_param.save() with _parameter_path and the given iteration number.
		void save_parameters(const int iter) {
			m_param.save(_parameter_path, iter);
		}

		// Forward to m_param.load() with _parameter_path and the given iteration number.
		void load_parameters(const int iter) {
			m_param.load(_parameter_path, iter);
		}

		// Forward to m_param.delete_file() with _parameter_path and the given iteration number.
		void delete_parameters(const int iter) {
			m_param.delete_file(_parameter_path, iter);
		}

		// Scoring helpers; defined outside this header. parse() calls compute_all_probs()
		// after the feature vectors of the instance have been created.
		void dot_all(const fvec * const fs, double * const probs, const int sz) const;

		void compute_all_probs(Instance *inst) const;

		// Called by parse() in its non-MBR branch when _verify_decoding_algorithm is set.
		void verify_decoding_algorithm( Instance * const inst);

		void evaluate_one_instance(const Instance * const inst);

		// Gradient helpers; defined outside this header. The first pair accumulates into a
		// dense floatval_t/double array, the second pair into a sparsevec.
		void objective_and_gradients_batch(Parser *par, const floatval_t * const w, floatval_t &f, floatval_t * const g, const int n);
		static void update_gradient(floatval_t *g, const fvec &fv, const double marg, const int n);
		static void update_gradient_one_inst(Parser *par, Decoder *decoder, const Instance *inst, double *g, const double gain);

		static void update_gradient(sparsevec &g, const fvec &fv, const double marg, const int n);
		static void update_gradient_one_inst(Parser *par, Decoder *decoder, const Instance *inst, sparsevec &g, const double gain);
/*
		void eval_oov_pos( const Instance *inst, int &oov_num, int &oov_pos_correct_num )
		{
			assert(false);
			for (int i = 1; i < inst->size(); ++i)
				if ( m_fgen.get_word_id(inst->forms[i]) < 0 ) {
					++oov_num;
					if (inst->cpostags[i] == inst->predicted_str_joint[i])
						++oov_pos_correct_num;
				}
		}

		int error_num_pos( const Instance *inst ) const
		{
			int error_num = 0;
			for (int i = 1; i < inst->size(); ++i) {
				const string &gold = inst->cpostags[i];
				const string &sys = inst->predicted_postags[i];
				vector<string> vecgold;
				vector<string> vecsys;
				simpleTokenize(gold, vecgold, "^");
				simpleTokenize(sys, vecsys, "^");
				assert(vecgold.size() <= 2 && vecsys.size() <= 2);
				if (vecgold.size() == 2) {
					assert(vecsys.size() == 2);
					if (vecgold[0] == "*") {
						if (vecgold[1] != vecsys[1]) ++error_num;
					} else if (vecgold[1] == "*") {
						if (vecgold[0] != vecsys[0]) ++error_num;
					} else {
						if (gold != sys) ++error_num;
					}
				} else {
					if (gold != sys) ++error_num;
				}
			}

			return error_num;
		}
		*/
	};
}


#endif

